# %%
# Helper functions used in the data analysis for the chiral coupler

# imports
import matplotlib as mpl
from matplotlib import pyplot as plt
from matplotlib import ticker
from matplotlib import colors
from scipy.integrate import solve_ivp, quad
from scipy.interpolate import interp1d
from scipy.fft import fft, ifft
from scipy.optimize import fsolve, broyden1, root
from functools import partial
import os
import numpy as np
import qutip as qt
from matplotlib import pyplot as plt
import seaborn as sns

from lmfit import Parameters, fit_report, minimize

from labcore.data.datadict_storage import datadict_from_hdf5
from qcui_analysis.fitfuncs.resonators import fit_and_plot_resonator_response, HangerResponseBruno, ReflectionResponse


# %%
# functions for setting up figures
def setup_figure(sns_style='whitegrid', rcparams=None):
    """Apply the shared matplotlib/seaborn styling used for the analysis figures.

    Parameters
    ----------
    sns_style : str
        Style name passed to ``sns.set_style``.
    rcparams : dict, optional
        Extra matplotlib rcParams applied last, so they override every default
        set here. ``None`` (the default) applies nothing extra. (Previously a
        mutable ``{}`` default; ``None`` avoids sharing one dict across calls.)
    """
    # some sensible defaults for sizing, those are for a typical print-plot
    mpl.rcParams.update({
        'figure.constrained_layout.use': True,
        'figure.dpi': 150,  # 300
        'font.family': ('Arial', 'Helvetica'),
        'font.size': 7,  # 6
        'lines.markersize': 3,
        'lines.linewidth': 1.5,
        'axes.linewidth': 0.5,
        'grid.linewidth': 0.5,
        'legend.fontsize': 5,
        'legend.frameon': False,
        'xtick.major.width': 0.5,
        'ytick.major.width': 0.5,
        'mathtext.fontset': 'dejavusans',
    })

    # tick settings (the tick size used to be set to 2 and then immediately
    # overwritten with 1.5; only the effective value is kept)
    mpl.rcParams.update({
        'xtick.labelcolor': 'black',
        'xtick.major.size': 1.5,
        'xtick.direction': 'out',
        'ytick.labelcolor': 'black',
        'ytick.major.size': 1.5,
        'ytick.direction': 'out',
    })

    sns.set_style(sns_style)
    # Re-applied after sns.set_style, whose style presets may reset it.
    mpl.rcParams['axes.labelcolor'] = 'black'
    if rcparams:
        mpl.rcParams.update(rcparams)

def _format_single_axis(axis, set_lim, ticks, lim):
    """Apply a tick locator (and optional limits) to one matplotlib axis.

    ``ticks`` given as an explicit sequence uses fixed tick positions; an int
    is a tick count (evenly spaced within ``lim`` when given, otherwise
    "nice" ticks from MaxNLocator).
    """
    if isinstance(ticks, (list, tuple, np.ndarray)):
        axis.set_major_locator(ticker.FixedLocator(ticks))
        if lim is not None:
            set_lim(lim)
    elif lim is not None:
        axis.set_major_locator(ticker.LinearLocator(ticks))
        set_lim(lim)
    else:
        axis.set_major_locator(ticker.MaxNLocator(ticks))


def format_axes(
        ax,
        xlabel=None,
        ylabel=None,
        xlim=None,
        ylim=None,
        xticks=3,
        yticks=3
):
    """Set tick locators, limits and labels on ``ax`` in one call.

    Parameters
    ----------
    ax : matplotlib Axes
    xlabel, ylabel : str, optional
        Axis labels; left unchanged when None.
    xlim, ylim : (float, float), optional
        Axis limits; left unchanged when None.
    xticks, yticks : int or sequence of float
        Tick count, or explicit tick positions when a list/tuple/array.
    """
    _format_single_axis(ax.xaxis, ax.set_xlim, xticks, xlim)
    # Previously the explicit-ticks branch applied ylim via set_xlim; route it
    # through set_ylim.
    _format_single_axis(ax.yaxis, ax.set_ylim, yticks, ylim)

    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)


def add_legend_outside(
        ax,
        handles,
        labels,
        anchor_point=(1, 1),
        legend_ref_point='lower right'
):
    """Draw a legend for ``handles``/``labels`` anchored at ``anchor_point``.

    ``legend_ref_point`` chooses which corner of the legend box sits on the
    anchor (matplotlib ``loc``); ``borderpad`` is forced to 0 so the box sits
    flush against that point.
    """
    placement = {
        'bbox_to_anchor': anchor_point,
        'borderpad': 0,
        'loc': legend_ref_point,
    }
    ax.legend(handles, labels, **placement)


# %%

# functions for loading data
def load_pump_phase_sweep_data(fn):
    """Load a pump-phase sweep from an HDF5 datadict file.

    Parameters
    ----------
    fn : path-like
        File passed to ``datadict_from_hdf5``.

    Returns
    -------
    frequency, phase, response
        Values of the ``'frequency'`` field, the ``'phase'`` field (or
        ``'QICK_phase'`` when ``'phase'`` is absent) and the ``'trace'`` field.
    """
    data = datadict_from_hdf5(fn)
    frequency = data.data_vals('frequency')
    # Narrowed from a bare ``except:`` so that KeyboardInterrupt and real
    # errors are no longer swallowed; a missing field surfaces as KeyError.
    try:
        phase = data.data_vals('phase')
    except KeyError:
        phase = data.data_vals('QICK_phase')
    response = data.data_vals('trace')
    return frequency, phase, response


def load_sa_data(fn):
    """Load a spectrum-analyzer phase sweep from an HDF5 datadict file.

    Parameters
    ----------
    fn : path-like
        File passed to ``datadict_from_hdf5``.

    Returns
    -------
    frequency, phase, response
        ``'frequency'`` and ``'trace'`` fields plus the ``'phase'`` field
        (or ``'QICK_phase'`` when ``'phase'`` is absent).
    """
    data = datadict_from_hdf5(fn)
    frequency = data.data_vals('frequency')
    response = data.data_vals('trace')
    # Narrowed from a bare ``except:``; a missing field surfaces as KeyError.
    try:
        phase = data.data_vals('phase')
    except KeyError:
        phase = data.data_vals('QICK_phase')
    return frequency, phase, response


def load_single_sa_data(fn):
    """Return the ``'frequency'`` and ``'trace'`` fields of a single SA trace file."""
    dd = datadict_from_hdf5(fn)
    return dd.data_vals('frequency'), dd.data_vals('trace')


def load_single_vna_data(fn):
    """Return the ``'frequency'`` and ``'trace'`` fields of a single VNA trace file."""
    dd = datadict_from_hdf5(fn)
    return tuple(dd.data_vals(field) for field in ('frequency', 'trace'))


def load_single_pump_data(fn):
    """Return ``'frequency'``, ``'pump_freq'`` and ``'trace'`` from a pump sweep file."""
    dd = datadict_from_hdf5(fn)
    freq_vals = dd.data_vals('frequency')
    pump_vals = dd.data_vals('pump_freq')
    trace_vals = dd.data_vals('trace')
    return freq_vals, pump_vals, trace_vals


def load_flux_sweep_data(fn):
    """Load a flux (bias-current) sweep from an HDF5 datadict file.

    Parameters
    ----------
    fn : path-like
        File passed to ``datadict_from_hdf5``.

    Returns
    -------
    frequency, current, trace
        ``'frequency'`` and ``'trace'`` fields plus the bias current from
        ``'yoko_current'`` (or ``'koko_current'`` when that is absent).
    """
    data = datadict_from_hdf5(fn)
    frequency = data.data_vals('frequency')
    trace = data.data_vals('trace')
    # Narrowed from a bare ``except:``; a missing field surfaces as KeyError.
    try:
        current = data.data_vals('yoko_current')
    except KeyError:
        current = data.data_vals('koko_current')
    return frequency, current, trace


# %%
# functions for smoothing data
def savitzky_golay(y, window_size, order, deriv=0, rate=1):
    r"""Smooth (and optionally differentiate) data with a Savitzky-Golay filter.
    The Savitzky-Golay filter removes high frequency noise from data.
    It has the advantage of preserving the original shape and
    features of the signal better than other types of filtering
    approaches, such as moving averages techniques.
    Parameters
    ----------
    y : array_like, shape (N,)
        the values of the time history of the signal.
    window_size : int
        the length of the window. Must be an odd integer number.
    order : int
        the order of the polynomial used in the filtering.
        Must be less than `window_size` - 1.
    deriv : int
        the order of the derivative to compute (default = 0 means only smoothing)
    rate : float
        scale factor for derivatives: the filter weights are multiplied by
        ``rate ** deriv * deriv!`` (use the sample rate to get derivatives
        per unit time instead of per sample).
    Returns
    -------
    ys : ndarray, shape (N,)
        the smoothed signal (or its n-th derivative).
    Raises
    ------
    ValueError
        if `window_size` or `order` cannot be converted to int.
    TypeError
        if `window_size` is not a positive odd number or is too small for `order`.
    Notes
    -----
    The Savitzky-Golay is a type of low-pass filter, particularly
    suited for smoothing noisy data. The main idea behind this
    approach is to make for each point a least-square fit with a
    polynomial of high order over an odd-sized window centered at
    the point. The ends of the signal are padded by mirroring the
    first/last `window_size // 2` samples about the end values.
    Examples
    --------
    t = np.linspace(-4, 4, 500)
    y = np.exp( -t**2 ) + np.random.normal(0, 0.05, t.shape)
    ysg = savitzky_golay(y, window_size=31, order=4)
    import matplotlib.pyplot as plt
    plt.plot(t, y, label='Noisy signal')
    plt.plot(t, np.exp(-t**2), 'k', lw=1.5, label='Original signal')
    plt.plot(t, ysg, 'r', label='Filtered signal')
    plt.legend()
    plt.show()
    References
    ----------
    .. [1] A. Savitzky, M. J. E. Golay, Smoothing and Differentiation of
       Data by Simplified Least Squares Procedures. Analytical
       Chemistry, 1964, 36 (8), pp 1627-1639.
    .. [2] Numerical Recipes 3rd Edition: The Art of Scientific Computing
       W.H. Press, S.A. Teukolsky, W.T. Vetterling, B.P. Flannery
       Cambridge University Press ISBN-13: 9780521880688
    """
    from math import factorial

    try:
        window_size = np.abs(int(window_size))
        order = np.abs(int(order))
    except ValueError as err:
        raise ValueError("window_size and order have to be of type int") from err
    if window_size % 2 != 1 or window_size < 1:
        raise TypeError("window_size size must be a positive odd number")
    if window_size < order + 2:
        raise TypeError("window_size is too small for the polynomials order")

    # Accept lists/tuples as well as arrays (list - scalar used to fail below).
    y = np.asarray(y)
    order_range = range(order + 1)
    half_window = (window_size - 1) // 2
    # Design matrix: row for offset k holds [k**0, k**1, ..., k**order].
    # Plain ndarray instead of np.mat/.A, which were removed in NumPy 2.0.
    b = np.array([[k ** i for i in order_range]
                  for k in range(-half_window, half_window + 1)], dtype=float)
    # Row `deriv` of the pseudo-inverse gives the least-squares weights of the
    # deriv-th polynomial coefficient; scale it to the deriv-th derivative.
    m = np.linalg.pinv(b)[deriv] * rate ** deriv * factorial(deriv)
    # pad the signal at the extremes with
    # values taken from the signal itself
    firstvals = y[0] - np.abs(y[1:half_window + 1][::-1] - y[0])
    lastvals = y[-1] + np.abs(y[-half_window - 1:-1][::-1] - y[-1])
    y = np.concatenate((firstvals, y, lastvals))
    return np.convolve(m[::-1], y, mode='valid')


# %%
# functions for SNAIL calculation and flux sweep fit
# eliminate bad points in flux sweep data
def eliminate_bad_points(f_0, threshold=2e9):
    """Replace points that jump away from their predecessor with NaN.

    Element ``i >= 1`` becomes NaN when ``|f_0[i] - f_0[i - 1]| > threshold``.
    The comparison always uses the original (unmasked) neighbour, and the
    first element is always kept.

    Parameters
    ----------
    f_0 : array_like, shape (N,)
        Values to clean (e.g. fitted resonance frequencies of a flux sweep).
    threshold : float
        Largest allowed absolute step between consecutive points.

    Returns
    -------
    ndarray of float, shape (N,)
        A copy of ``f_0`` with the outlier points set to NaN. The input is
        not modified.
    """
    values = np.asarray(f_0, dtype=float)
    cleaned = values.copy()
    # The original also computed an unused np.gradient(f_0), which raised for
    # inputs shorter than 2 points; it has been dropped.
    if cleaned.size > 1:
        jumps = np.abs(np.diff(values)) > threshold
        cleaned[1:][jumps] = np.nan
    return cleaned


# SNAIL coefficient calculation, here we only care about the 2nd order
# as we are trying to fit the data
def c1_coeff(phi, phi_ext, alpha):
    """First-order Taylor coefficient of the SNAIL potential at phase ``phi``.

    Evaluates ``alpha * sin(phi) - sin((phi_ext - phi) / 3)``; ``solve`` /
    ``phi_m`` look for the ``phi`` where it vanishes.
    """
    shared_term = np.sin((phi_ext - phi) / 3)
    alpha_term = alpha * np.sin(phi)
    return -shared_term + alpha_term


def c2_coeff(phi, phi_ext, alpha):
    """Second-order Taylor coefficient of the SNAIL potential at phase ``phi``.

    Evaluates ``(cos((phi_ext - phi) / 3) / 3 + alpha * cos(phi)) / 2``.
    """
    shared_term = np.cos((phi_ext - phi) / 3) / 3
    alpha_term = alpha * np.cos(phi)
    return (shared_term + alpha_term) / 2.0


# Defining a solver to get the phi_min given a phi_ext
def solve(phi_ext, alpha, unsolved_value):
    """Residual vector ``[c1(phi)]`` for the phi_min root find.

    The unknown comes last so ``partial(solve, phi_ext, alpha)`` yields a
    function of the unknown phase only.
    """
    return [c1_coeff(unsolved_value, phi_ext, alpha)]


# Now for a given range of external flux values let's compute the corresponding phi_min's for the expansion
def phi_m(initial_guess, phi_ext, alpha):
    """Find the phase where ``c1_coeff`` vanishes for the given external flux.

    Uses SciPy's ``broyden1`` starting from ``initial_guess`` with
    ``x_tol=1e-10`` and up to 100000 iterations.
    """
    residual_fn = partial(solve, phi_ext, alpha)
    return broyden1(residual_fn, initial_guess, x_tol=1e-10, maxiter=100000)


def L_j(initial_guess, phi_ext, alpha, L_0):
    """SNAIL inductance ``L_0 / 2 / c2`` evaluated at phi_min for ``phi_ext``."""
    minimum_phase = phi_m(initial_guess, phi_ext, alpha)
    second_order = c2_coeff(minimum_phase, phi_ext, alpha)
    return L_0 / 2 / second_order


def get_c2(initial_guess, phi_ext, alpha, L_0):
    """Return ``c2_coeff`` evaluated at phi_min for the given external flux.

    ``L_0`` is not used here; the signature mirrors ``L_j``.
    """
    return c2_coeff(phi_m(initial_guess, phi_ext, alpha), phi_ext, alpha)


# functions for fitting the flux sweep
def snail_calculation(L_0, alpha, phi_exts=None):
    """Compute the SNAIL inductance for each external flux phase.

    Parameters
    ----------
    L_0, alpha : float
        Passed through to ``L_j``.
    phi_exts : array_like, optional
        External flux phases in radians; defaults to 501 points over
        [-1.5*pi, 1.5*pi].

    Returns
    -------
    phi_exts, L_js
        The flux phases used and the matching inductances (float array).
        Every point is solved from the same initial guess 0.0.
    """
    if phi_exts is None:
        phi_exts = np.linspace(-1.5, 1.5, 501) * np.pi
    start = 0.0
    inductances = np.zeros(len(phi_exts))
    for idx, flux_phase in enumerate(phi_exts):
        inductances[idx] = L_j(start, flux_phase, alpha, L_0)
    return phi_exts, inductances


def flux_sweep_residual(pars, current, data=None):
    """Model (or residual of) the flux-tuned resonance frequency vs bias current.

    The bias current is mapped linearly to external flux: ``I_0`` gives zero
    flux and ``I_1`` gives -0.5 flux quanta (``phi_ext = 2*pi*flux``).
    ``snail_calculation`` converts that to the SNAIL inductance ``L_j``, and
    the model frequency is

        f = 1 / sqrt(c_snail * (|L_j| + L_linear)) / (2*pi) / 1e9

    (i.e. GHz when the parameters are in SI units).

    Parameters
    ----------
    pars : lmfit.Parameters (anything with ``valuesdict()``)
        Must provide 'L_linear', 'c_snail', 'I_0', 'I_1', 'L_0', 'alpha'.
    current : ndarray, shape (N,)
        Bias currents.
    data : ndarray, shape (N,), optional
        Measured frequencies in the same units as the model.

    Returns
    -------
    ndarray
        Model frequencies when ``data`` is None; otherwise the flattened
        residual ``model - data`` viewed as float.
    """
    vals = pars.valuesdict()
    L_linear = vals['L_linear']
    c_snail = vals['c_snail']
    I_0 = vals['I_0']
    I_1 = vals['I_1']
    L_0 = vals['L_0']
    alpha = vals['alpha']

    flux_from_current = -(current - I_0) * 0.5 / (I_1 - I_0)
    _, L_js = snail_calculation(L_0, alpha, phi_exts=flux_from_current * 2 * np.pi)

    # Vectorised form of the former per-point loop (same arithmetic).
    cal_results = 1.0 / np.sqrt(c_snail * (np.abs(L_js) + L_linear)) / 2 / np.pi / 1e9

    if data is None:
        return cal_results

    residual_result = (cal_results - data).flatten()
    return residual_result.view(float)


# %%
# function for solving S-matrix
def solve_S_matrix(probe_detune=0,
                   d1=0,
                   a11=0,
                   a12=0,
                   a21=0,
                   a22=0,
                   d2=0,
                   a31=0,
                   a32=0,
                   gp1=0.5,
                   gp2=0.5,
                   phip1=0,
                   phip2=np.pi / 2,
                   gamma1=1.0,
                   ki1=0.0,
                   gamma2=1.0,
                   ki2=0.0,
                   gammab=1.0,
                   kbi=0.0,
                   g0=-1.0,
                   phi0=np.pi / 2,
                   epsilon=0,
                   dLin=1,
                   dRin=0,
                   b_in=0):
    '''
    Solve the linear input-output equations of the (a1, a2, b) mode model
    at a single probe detuning.

    A 6x6 system ``eqs @ x = vals`` is built and solved with
    ``np.linalg.solve``. The unknown vector is ordered
    [a1, a2, b, bout, dLout, dRout].

    The meaning of the inputs:
        probe_detune - probe detuning; added to the diagonal term of all three modes
        d1 - extra detuning added to the a1 diagonal term
        d2 - extra detuning added to the a2 diagonal term
             (NOTE(review): older docs described d1 as the a1/a2 detuning and
             d2 as the b-mode detuning; the b row below uses neither -- confirm intent)
        aij - stark shift due to the pump, i and j means it comes from the Kerr from jth mode to ith mode
              (a11/a12 enter the a1 row, a21/a22 the a2 row, a31/a32 the b row)
        gp1 - effective g value from pump between a1 and b
        gp2 - effective g value from pump between a2 and b
        phip1 - phase of the pump between a1 and b
        phip2 - phase of the pump between a2 and b
        gamma1/2 - external kappa value of the a1/2 mode
        ki1/2 - internal kappa value of the a1/2 mode
        gammab - external kappa value of the b mode
        kbi - internal kappa value of the b mode
        g0 - coherent a1-a2 coupling: enters the a1/a2 off-diagonal entries as
             1j*g0 alongside the dissipative sqrt(gamma1*gamma2)*exp(1j*phi0)
             term (intended to cancel the bus-like interaction)
        phi0 - phase from the separation between the a1 and a2 modes on the transmission line
        epsilon - ratio for pump leakage: adds epsilon*gp2*exp(1j*phip2) to the
                  a1-b coupling and epsilon*gp1*exp(1j*phip1) to the a2-b coupling
                  (only in the a1/a2 rows; the b row does not include it)
        dLin - input signal from left port of the transmission line
        dRin - input signal from right port of the transmission line
        b_in - input signal from b port

    Returns
    -------
    ndarray of complex, shape (6,)
        Solution ordered [a1, a2, b, bout, dLout, dRout].

    Raises
    ------
    numpy.linalg.LinAlgError
        If the system matrix is singular.
    '''
    ## in the form of [a1, a2, b, bout, dLout, dRout]
    # Rows 0-2: mode equations for a1, a2, b.
    # Rows 3-5: input-output relations; their unit entries sit in the
    # dLout, dRout and bout columns respectively.
    eqs = np.array([[1j * (probe_detune + d1 + a11 + a12) + (2 * gamma1 + ki1) / 2.0,
                     1j * g0 + np.sqrt(gamma1 * gamma2) * np.exp(1j * phi0),
                     1j * (gp1 * np.exp(1j * phip1) + epsilon * gp2 * np.exp(1j * phip2)),
                     0,
                     0,
                     0],
                    [1j * g0 + np.sqrt(gamma1 * gamma2) * np.exp(1j * phi0),
                     1j * (probe_detune + d2 + a21 + a22) + (2 * gamma2 + ki2) / 2.0,
                     1j * (epsilon * gp1 * np.exp(1j * phip1) + gp2 * np.exp(1j * phip2)),
                     0,
                     0,
                     0],
                    [1j * gp1 * np.exp(-1j * phip1),
                     1j * gp2 * np.exp(-1j * phip2),
                     1j * (probe_detune + a31 + a32) + (gammab + kbi) / 2.0,
                     0,
                     0,
                     0],
                    [-np.sqrt(gamma1),
                     -np.sqrt(gamma2) * np.exp(1j * phi0),
                     0,
                     0,
                     1,
                     0],
                    [-np.sqrt(gamma1),
                     -np.sqrt(gamma2) * np.exp(-1j * phi0),
                     0,
                     0,
                     0,
                     1],
                    [0,
                     0,
                     -np.sqrt(gammab),
                     1,
                     0,
                     0]])

    # Right-hand side: port drives entering the mode equations (rows 0-2)
    # and the direct input terms of the input-output relations (rows 3-5).
    vals = np.array([-np.sqrt(gamma1) * dLin - np.sqrt(gamma1) * dRin,
                     -np.sqrt(gamma2) * np.exp(-1j * phi0) * dLin - np.sqrt(gamma2) * np.exp(1j * phi0) * dRin,
                     -np.sqrt(gammab) * b_in,
                     dLin,
                     dRin,
                     b_in])

    return np.linalg.solve(eqs, vals)


# %%
# fit function for the coupling cancellation

def residual_coupling_fit_model(pars, f_probe, kappas, data=None):
    """Unpumped coupling-cancellation model (or its residual) vs probe frequency.

    ``solve_S_matrix`` is evaluated with both pump couplings set to zero.
    The coherent coupling is ``g0 = -sqrt(gamma1*gamma2) * g_c_ratio`` and the
    separation phase is ``phi0 = pi/2 * phi0_ratio``. Solution component 4,
    rotated by ``exp(1j*phase_offset)``, is the model.

    Parameters
    ----------
    pars : lmfit.Parameters (anything with ``valuesdict()``)
        'g_c_ratio', 'phi0_ratio', 'phase_offset', 'f_0'.
    f_probe : sequence of float
        Probe frequencies; the detuning passed to the solver is ``f - f_0``.
    kappas : dict
        'gamma1', 'gamma2', 'gammab', 'ki1', 'ki2', 'kbi'.
    data : ndarray of complex, optional
        Measured response to compare against.

    Returns
    -------
    ndarray
        The complex model when ``data`` is None, otherwise the flattened
        complex residual viewed as float (real/imag interleaved).
    """
    p = pars.valuesdict()
    g_c_ratio, phi0_ratio = p['g_c_ratio'], p['phi0_ratio']
    phase_offset, f_0 = p['phase_offset'], p['f_0']

    gamma1, gamma2, gammab = kappas['gamma1'], kappas['gamma2'], kappas['gammab']
    ki1, ki2, kbi = kappas['ki1'], kappas['ki2'], kappas['kbi']

    coherent_g0 = -np.sqrt(gamma1 * gamma2) * g_c_ratio
    separation_phase = np.pi * 0.5 * phi0_ratio

    solutions = np.zeros((len(f_probe), 6), dtype=complex)
    for row, freq in enumerate(f_probe):
        solutions[row, :] = solve_S_matrix(probe_detune=freq - f_0, d2=0.0,
                                           gamma1=gamma1, ki1=ki1,
                                           gamma2=gamma2, ki2=ki2,
                                           gammab=gammab, kbi=kbi,
                                           gp1=0, gp2=0,
                                           g0=coherent_g0,
                                           phi0=separation_phase)

    model = solutions[:, 4] * np.exp(1.0j * phase_offset)
    if data is None:
        return model
    return (model - data).flatten().view(float)


# %%
# functions for processing pump-induced avoided-crossing data
def pump_avoid_crossing_residual(pars, probe_freq, pump_freq, data=None):
    """Model (or residual of) the pump-induced avoided-crossing map.

    Every rate/frequency parameter is multiplied by 2*pi. For each pump
    frequency ``p`` the model row is

        amp_adjust * (coeff + kappa_b_e) / coeff,
        coeff = -1j*(probe - w_s)
                + g**2 / (-1j*(probe - p - w_r) - kappa_a/2)
                - (kappa_b_e + kappa_b_i)/2,

    with ``amp_adjust = amp_adjust_mag * exp(1j * amp_adjust_phase)``.

    Returns
    -------
    ndarray
        Complex array of shape (len(pump_freq), len(probe_freq)) when
        ``data`` is None, otherwise the flattened complex residual viewed as
        float (real/imag interleaved).
    """
    params = pars.valuesdict()

    def angular(name):
        return params[name] * 2 * np.pi

    w_s, w_r, g = angular('w_s'), angular('w_r'), angular('g')
    kappa_a = angular('kappa_a')
    kappa_b_e = angular('kappa_b_e')
    kappa_b_i = angular('kappa_b_i')
    amp_adjust = params['amp_adjust_mag'] * np.exp(1j * params['amp_adjust_phase'])

    model = np.zeros((len(pump_freq), len(probe_freq)), dtype=complex)
    for row, pump in enumerate(pump_freq):
        pump_term = g ** 2 / (-1j * (probe_freq - pump - w_r) - kappa_a / 2)
        coeff = -1j * (probe_freq - w_s) + pump_term - (kappa_b_e + kappa_b_i) / 2
        model[row] = amp_adjust * (coeff + kappa_b_e) / coeff

    if data is None:
        return model
    return (model - data).flatten().view(float)


# plot function for fit
def plot_avoid_crossing(frequency_, frequency, smooth_data, fit_data):
    """Plot the smoothed measured |S| map above the fitted one.

    Both panels share the colour range of the measured data. Axes are pump
    frequency (``frequency_``) and probe frequency (``frequency[0]``), each
    divided by 1e9.
    """
    fig, ax = plt.subplots(2, 1, sharex=True, sharey=True)

    pump_axis = frequency_ / 1e9
    probe_axis = frequency[0] / 1e9
    measured_mag = np.abs(smooth_data.transpose())
    color_range = {'vmin': measured_mag.min(), 'vmax': measured_mag.max()}

    top_mesh = ax[0].pcolormesh(pump_axis, probe_axis, measured_mag, **color_range)
    ax[0].set_ylabel('Probe Freq(GHz)')
    fig.colorbar(top_mesh, ax=ax[0])

    bottom_mesh = ax[1].pcolormesh(pump_axis, probe_axis, np.abs(fit_data.transpose()), **color_range)
    ax[1].set_xlabel('Pump Freq (GHz)')
    ax[1].set_ylabel('Probe Freq(GHz)')
    fig.colorbar(bottom_mesh, ax=ax[1])


# smooth function for the complex data
def get_smooth_avoid_crossing_data(response):
    """Savitzky-Golay smooth each row of a complex 2D response.

    Real and imaginary parts are filtered separately (window 51, order 3).
    """
    smoothed = np.zeros(response.shape, dtype=complex)
    for idx, row in enumerate(response):
        real_part = savitzky_golay(row.real, 51, 3)
        imag_part = savitzky_golay(row.imag, 51, 3)
        smoothed[idx] = real_part + 1j * imag_part
    return smoothed


def load_single_QICK_pump_data(fn):
    """Load a QICK pump-frequency sweep from an HDF5 datadict file.

    Parameters
    ----------
    fn : path-like
        File passed to ``datadict_from_hdf5``.

    Returns
    -------
    frequency, pump_freq, response
        ``'frequency'`` and ``'trace'`` fields, plus the pump frequency from
        ``'QICK_pump1_fre'`` (or ``'QICK_pump2_fre'`` when that is absent)
        multiplied by 1e6 (presumably MHz -> Hz; confirm with the acquisition code).
    """
    data = datadict_from_hdf5(fn)
    frequency = data.data_vals('frequency')
    response = data.data_vals('trace')
    # Narrowed from a bare ``except:``; a missing field surfaces as KeyError.
    try:
        pump_freq = data.data_vals('QICK_pump1_fre')
    except KeyError:
        pump_freq = data.data_vals('QICK_pump2_fre')
    return frequency, pump_freq * 1e6, response


# batch fit avoid crossing data
def batch_fitting_avoid_crossing(filenames, fit_params):
    """Fit the pumped avoided-crossing model to every file in ``filenames``.

    Each file is loaded, smoothed and fitted with ``lmfit.minimize`` starting
    from ``fit_params``. Frequencies are converted with ``* 2 * pi / 1e6``
    before fitting. The file index is printed as progress.

    Returns
    -------
    list
        One ``minimize`` result per file, in input order.
    """
    results = []
    for index, filename in enumerate(filenames):
        print(index)
        probe, pump, response = load_single_QICK_pump_data(filename)
        smoothed = get_smooth_avoid_crossing_data(response)
        fit = minimize(pump_avoid_crossing_residual, fit_params,
                       args=(probe[0] * 2 * np.pi / 1e6, pump * 2 * np.pi / 1e6),
                       kws={'data': smoothed})
        results.append(fit)
    return results


def plot_batch_fitting_results(filenames, fit_results):
    """Reload each file and plot its smoothed data against the matching fit.

    ``fit_results[i]`` must correspond to ``filenames[i]``.
    """
    for idx, filename in enumerate(filenames):
        probe, pump, response = load_single_QICK_pump_data(filename)
        smoothed = get_smooth_avoid_crossing_data(response)
        model = pump_avoid_crossing_residual(fit_results[idx].params,
                                             probe[0] * 2 * np.pi / 1e6,
                                             pump * 2 * np.pi / 1e6)
        plot_avoid_crossing(pump, probe, smoothed, model)
    return None


# linear fit for g vs pump voltage
def linear_residual(pars, x, data=None):
    """Straight-line model ``k * x + b`` (or its residual against ``data``).

    Used to fit the effective coupling g against pump voltage.

    Parameters
    ----------
    pars : lmfit.Parameters (anything with ``valuesdict()``)
        Must provide 'k' (slope) and 'b' (intercept).
    x : array_like
        Independent variable; lists are accepted.
    data : array_like, optional
        Measured values; lists are accepted (``fit_geff_vs_pump_voltage``
        passes a list).

    Returns
    -------
    ndarray
        The model when ``data`` is None, otherwise the flattened residual
        ``model - data`` viewed as float.
    """
    vals = pars.valuesdict()
    k = vals['k']
    b = vals['b']

    cal_results = k * np.asarray(x) + b

    if data is None:
        return cal_results

    residual_result = (cal_results - np.asarray(data)).flatten()
    return residual_result.view(float)


def fit_geff_vs_pump_voltage(g_fit_results, pump_voltage, pump_fit_params):
    """Linearly fit the fitted ``g`` values against pump voltage.

    Returns the ``lmfit.minimize`` result of ``linear_residual``, starting
    from ``pump_fit_params``.
    """
    g_values = [result.params['g'].value for result in g_fit_results]
    return minimize(linear_residual, pump_fit_params,
                    args=(pump_voltage,), kws={'data': g_values})


def plot_geff_fit(linear_fit_result, g_fit_results, pump_voltage):
    """Plot the fitted ``g`` values vs pump voltage with the linear-fit overlay.

    Returns
    -------
    g_values : list of float
        ``params['g'].value`` from each entry of ``g_fit_results``.
    fit_data : ndarray
        ``linear_residual`` evaluated at ``pump_voltage``.
    """
    g_values = [result.params['g'].value for result in g_fit_results]
    fit_data = linear_residual(linear_fit_result.params, pump_voltage)

    fig, ax = plt.subplots(1, 1)
    for series, fmt, name in ((g_values, 'o', 'data'), (fit_data, '-', 'linear fit')):
        ax.plot(pump_voltage, series, fmt, label=name)
    ax.legend(loc=0)
    ax.set_xlabel('DAC')
    ax.set_ylabel('Effective g (MHz)')
    return g_values, fit_data


# %% functions for processing isolation and gyration data

# plot the data and fit at the pass and block point
def plot_pass_block_data(f_pass, normalized_pass_data, f_block, normalized_block_data, f_center,
                         calculate_results=None):
    """Plot magnitude and phase of the "pass" and "block" traces, optionally with a model.

    Left column: pass data vs ``f_pass - f_center``. Right column: block data
    vs ``f_block - f_center``; the block column previously used ``f_pass`` by
    mistake. Top row: ``10*log10(|S|^2)``. Bottom row: phase in radians.

    Parameters
    ----------
    f_pass, f_block : ndarray
        Probe frequencies of the pass and block traces.
    normalized_pass_data, normalized_block_data : ndarray
        Normalized (complex) responses.
    f_center : float
        Frequency subtracted from both axes.
    calculate_results : dict, optional
        With keys 'deltas', 'results_pass', 'results_block'; drawn as model
        curves against ``deltas``.
    """
    fontsize = 14
    fig, ax = plt.subplots(2, 2, figsize=(4 * 2, 2 * 2), sharex=True, sharey=False)

    columns = ((f_pass, normalized_pass_data), (f_block, normalized_block_data))
    for col, (freqs, trace) in enumerate(columns):
        detune = freqs - f_center
        mag_ax, phase_ax = ax[0][col], ax[1][col]

        mag_ax.plot(detune, 10 * np.log10(np.abs(trace) ** 2),
                    'o', label='data', markersize=1.0)
        mag_ax.set_ylabel('S21 Mag', fontsize=fontsize)
        mag_ax.tick_params(axis='y', which='major', labelsize=fontsize)

        phase_ax.plot(detune, np.angle(trace), 'o', label='data', markersize=1.0)
        phase_ax.set_xlabel('Probe detune (MHz)', fontsize=fontsize)
        phase_ax.set_ylabel('S21 phase (rad)', fontsize=fontsize)
        phase_ax.tick_params(axis='x', which='major', labelsize=fontsize)
        phase_ax.tick_params(axis='y', which='major', labelsize=fontsize)
        phase_ax.set_ylim([-3.5, 3.5])

    if calculate_results is not None:
        deltas = calculate_results['deltas']
        for col, key in enumerate(('results_pass', 'results_block')):
            model = calculate_results[key]
            ax[0][col].plot(deltas, 10 * np.log10(np.abs(model) ** 2), label='model', linewidth=2)
            ax[1][col].plot(deltas, np.angle(model), label='model', linewidth=2)

    for row in ax:
        for panel in row:
            panel.legend(fontsize=fontsize)
            panel.grid(linewidth=1)

    return None


# plot the gyration data
def plot_gyration_data(phase, data, pump_phase, pump_phase_offset, calculate_results, amp_offset):
    """Overlay measured b-port output power on the model vs pump phase.

    The model x-axis is ``pump_phase`` converted to degrees plus
    ``pump_phase_offset``. The model y-axis is
    ``10*log10(calculate_results**2) + amp_offset``.
    """
    fig, ax = plt.subplots(1, 1, figsize=(4, 3))
    model_phase_deg = pump_phase * 180 / np.pi + pump_phase_offset
    model_power = 10 * np.log10(calculate_results ** 2) + amp_offset

    ax.plot(phase, data, 'o', label='data')
    ax.plot(model_phase_deg, model_power, '-', label='model', linewidth=2)
    ax.set_xlim([-10, 370])
    ax.set_xlabel('Pump phase (deg)')
    ax.set_ylabel('RT output power at b port (dBm)')
    ax.legend(loc=0)

    return None


# calculate the isolation and gyration data
def isolation_gyration_calculation(g_pump, g_pump_ratio, g_c_ratio, phi0_ratio, freq_detune, gammas):
    """Model the pass/block spectra and the full pump-phase sweep.

    All ``solve_S_matrix`` calls share ``d1 = d2 = freq_detune``, the decay
    rates from ``gammas``, ``gp1 = g_pump``, ``gp2 = g_pump * g_pump_ratio``,
    ``g0 = -sqrt(gamma1*gamma2) * g_c_ratio`` and ``phi0 = pi/2 * phi0_ratio``.
    Only the default port drive of ``solve_S_matrix`` is used.

    Returns
    -------
    deltas : ndarray, shape (501,)
        ``np.linspace(-5, 5, 501)`` probe detunings.
    pump_phases : ndarray, shape (201,)
        ``phip2`` values over [-2*pi, 2*pi].
    results_pass, results_block : ndarray of complex, shape (501, 6)
        Solutions with ``phip2 = +pi/2`` and ``-pi/2``.
    calculate_results : ndarray of complex, shape (501, 201, 6)
        Solutions for every (detuning, pump phase) pair.
    """
    shared = dict(
        d1=freq_detune, d2=freq_detune,
        gamma1=gammas['gamma1'], ki1=gammas['ki1'],
        gamma2=gammas['gamma2'], ki2=gammas['ki2'],
        gammab=gammas['gammab'], kbi=gammas['kbi'],
        gp1=g_pump, gp2=g_pump * g_pump_ratio,
        g0=-np.sqrt(gammas['gamma1'] * gammas['gamma2']) * g_c_ratio,
        phi0=np.pi * 0.5 * phi0_ratio,
    )

    deltas = np.linspace(-5, 5, 501)
    pump_phases = np.linspace(-2 * np.pi, 2 * np.pi, 201)

    results_pass = np.zeros((len(deltas), 6), dtype=complex)
    results_block = np.zeros((len(deltas), 6), dtype=complex)
    for idx, delta in enumerate(deltas):
        results_pass[idx] = solve_S_matrix(probe_detune=delta, phip2=np.pi / 2, **shared)
        results_block[idx] = solve_S_matrix(probe_detune=delta, phip2=-np.pi / 2, **shared)

    calculate_results = np.zeros((len(deltas), len(pump_phases), 6), dtype=complex)
    for col, pump_phase in enumerate(pump_phases):
        for idx, delta in enumerate(deltas):
            calculate_results[idx, col] = solve_S_matrix(probe_detune=delta, phip2=pump_phase, **shared)

    return deltas, pump_phases, results_pass, results_block, calculate_results


def get_isolation_gyration_data(fn, fn_background, sa_fn):
    """Load the isolation sweep (normalized) and the gyration (SA peak) data.

    The VNA pump-phase sweep in ``fn`` is divided by the mean magnitude of
    the background trace in ``fn_background``. For each row of the SA trace
    in ``sa_fn`` the maximum value is taken.

    Returns
    -------
    frequency, phase, normalized_response, sa_results
    """
    frequency, _vna_phase, response = load_pump_phase_sweep_data(fn)
    _, background_trace = load_single_vna_data(fn_background)
    normalized_response = response / np.mean(np.abs(background_trace))

    # NOTE(review): the returned `phase` is the one read from the SA file
    # (sa_fn), not the VNA sweep phase from `fn`; confirm both sweeps use
    # the same phase axis.
    _, phase, sa_trace = load_sa_data(sa_fn)
    sa_results = np.zeros(len(phase))
    for idx, _ in enumerate(phase):
        sa_results[idx] = sa_trace[idx].max()

    return frequency, phase, normalized_response, sa_results


# plot the full isolation sweep data
def _plot_isolation_panel(fig, axis, x, y, values, vmin, vmax, title, ylim=None):
    """Draw one colour map of the isolation figure with labels and a colorbar."""
    mesh = axis.pcolormesh(x, y, values, vmin=vmin, vmax=vmax)
    axis.set_xlabel('Frequency (MHz)')
    axis.set_ylabel('Pump phase (Deg)')
    axis.set_title(title)
    if ylim is not None:
        axis.set_ylim(ylim)
    fig.colorbar(mesh, ax=axis)


def plot_full_isolation_data(frequency, f_center, phase, data,
                             deltas, pump_phases, pump_phase_offset, calculate_results):
    """Compare measured and modelled isolation maps vs detuning and pump phase.

    Left column: measured ``10*log10(|data|^2)`` (colour range -40..1) and
    phase, plotted against ``frequency - f_center`` and ``phase``. Right
    column: solution component 4 of ``calculate_results`` (the same component
    the coupling fit uses), plotted against ``deltas`` and
    ``pump_phases`` in degrees plus ``pump_phase_offset``, limited to 0-360 deg.

    The model panels were previously labelled 'pump frequency (GHz)' /
    'probe frequency (GHz)'. They now use the same axis labels as the data
    panels they are compared with.
    """
    fig, ax = plt.subplots(2, 2, figsize=(3.5, 1.5 * 2))

    detune = frequency - f_center
    _plot_isolation_panel(fig, ax[0][0], detune, phase, 10 * np.log10(np.abs(data) ** 2),
                          -40.0, 1.0, 'Data Magnitude')
    _plot_isolation_panel(fig, ax[1][0], detune, phase, np.angle(data),
                          -np.pi, np.pi, 'Data Phase')

    # calculate_results is indexed (detuning, pump phase, component); transpose
    # so rows follow pump phase on the y-axis.
    model = calculate_results[:, :, 4].transpose()
    model_phase_deg = pump_phases * 180.0 / np.pi + pump_phase_offset
    _plot_isolation_panel(fig, ax[0][1], deltas, model_phase_deg, 10 * np.log10(np.abs(model) ** 2),
                          -40.0, 1.0, 'Model Magnitude', ylim=[0, 360])
    _plot_isolation_panel(fig, ax[1][1], deltas, model_phase_deg, np.angle(model),
                          -np.pi, np.pi, 'Model Phase', ylim=[0, 360])


# %%
# functions used to calculate the anticipated performance of the device

def calculate_full_S_matrix(deltas, gammas, g_pump=0.5, g_pump_ratio=1.0, g_c_ratio=1.0, phi0_ratio=1.0,
                            freq_detune=0.0):
    """Solve the model for a unit drive at each of the three input ports.

    Index 0, 1, 2 of the first axis corresponds to a unit input at ``dLin``,
    ``dRin`` and ``b_in`` respectively. "pass" uses ``phip2 = +pi/2`` and
    "block" uses ``phip2 = -pi/2``.

    Returns
    -------
    deltas : the input detunings (unchanged)
    results_pass, results_block : ndarray of complex, shape (3, len(deltas), 6)
        ``solve_S_matrix`` solutions per (input port, detuning).
    """
    shared = dict(
        d1=freq_detune, d2=freq_detune,
        gamma1=gammas['gamma1'], ki1=gammas['ki1'],
        gamma2=gammas['gamma2'], ki2=gammas['ki2'],
        gammab=gammas['gammab'], kbi=gammas['kbi'],
        gp1=g_pump, gp2=g_pump * g_pump_ratio,
        g0=-np.sqrt(gammas['gamma1'] * gammas['gamma2']) * g_c_ratio,
        phi0=np.pi * 0.5 * phi0_ratio,
    )
    # (dLin, dRin, b_in) drive for each input port, in output-array order.
    port_drives = ((1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0))

    results_pass = np.zeros((3, len(deltas), 6), dtype=complex)
    results_block = np.zeros((3, len(deltas), 6), dtype=complex)

    for port, (left_in, right_in, bus_in) in enumerate(port_drives):
        drive = dict(dLin=left_in, dRin=right_in, b_in=bus_in)
        for idx, delta in enumerate(deltas):
            results_pass[port, idx] = solve_S_matrix(probe_detune=delta, phip2=np.pi / 2,
                                                     **drive, **shared)
            results_block[port, idx] = solve_S_matrix(probe_detune=delta, phip2=-np.pi / 2,
                                                      **drive, **shared)

    return deltas, results_pass, results_block


def plot_full_S_matrix(deltas, results_pass, results_block):
    """Plot |S| for every port pair on a 3x3 grid, CW (solid) vs CCW (dashed).

    Grid row ``row`` draws ``results_pass[row]`` / ``results_block[row]``;
    grid column 0, 1, 2 draws column 5, 4, 3 of the last axis of those arrays.
    The result arrays are presumably shaped (3, len(deltas), >= 6), as filled by
    the S-matrix calculation above -- confirm against the caller.
    Nothing is returned; the figure is left open in pyplot.
    """
    labels = [[r'$S_{11}$', r'$S_{21}$', r'$S_{b1}$'],
              [r'$S_{12}$', r'$S_{22}$', r'$S_{b2}$'],
              [r'$S_{1b}$', r'$S_{2b}$', r'$S_{bb}$']]
    # grid column -> index into the last axis of the result arrays
    result_columns = (5, 4, 3)

    fig, axes = plt.subplots(3, 3, figsize=(3 * 2, 2 * 2), sharex=True, sharey=True)
    for row, row_labels in enumerate(labels):
        for col, (label, k) in enumerate(zip(row_labels, result_columns)):
            panel = axes[row][col]
            panel.plot(deltas, np.abs(results_pass[row, :, k]), label='CW ' + label)
            panel.plot(deltas, np.abs(results_block[row, :, k]), '--', label='CCW ' + label)
            panel.legend()


# %%
# functions to calculate the time-domain emission (pitch and catch) of the chiral coupler

def input_wave(gamma_ph, t):
    """Return the sech-shaped input drive ``1j * sqrt(gamma_ph) / (2 * cosh(gamma_ph * t / 2))``.

    ``t`` may be a scalar or a numpy array; the result is complex and purely imaginary.
    """
    envelope = 0.5 * np.sqrt(gamma_ph) / np.cosh(gamma_ph * t / 2.0)
    return 1j * envelope


def g_t(kappa, gamma_ph, t):
    """Time-dependent pump coupling used by the pitch-and-catch emission model.

    With ``x = gamma_ph * t / 2`` this evaluates::

        g(t) = (2*kappa*sqrt(gamma_ph) - gamma_ph**1.5 * tanh(x))
               / (-8 * sqrt(-gamma_ph/8 + kappa * (1 + exp(-2x)) / 4))

    which is algebraically identical to the exponential-ratio form used
    elsewhere in this file: (e^x + e^-3x + 2e^-x) / (e^x + e^-x) == 1 + e^-2x.

    Parameters
    ----------
    kappa : float
        Decay rate entering the envelope (called ``gamma`` in the formula).
    gamma_ph : float
        Width parameter of the envelope (same parameter as in ``input_wave``).
    t : float or np.ndarray
        Time(s) at which to evaluate.

    Returns
    -------
    float or np.ndarray
        ``g(t)``. It is nan wherever the square-root argument is negative.
    """
    gamma = kappa
    x = gamma_ph * t / 2.0
    # np.tanh saturates cleanly at +-1; the hand-written exp ratio it replaces
    # evaluated to inf/inf = nan once |x| exceeded ~709.
    tanh_x = np.tanh(x)

    up = -gamma_ph * np.sqrt(gamma_ph) * tanh_x + 2 * gamma * np.sqrt(gamma_ph)
    # For very negative x, exp(-2x) overflows to inf, which gives the correct
    # limit g -> 0, so the overflow warning is silenced deliberately.
    with np.errstate(over='ignore'):
        down = -8 * np.sqrt(-gamma_ph / 8 + gamma * (1.0 + np.exp(-2.0 * x)) / 4.0)
    return up / down


def langevin_emission_symmetric_vary_g(t, y, g_0, kappa, phi_0, phi_a, phi_c, gamma_ph, dl_in=0, dr_in=0, kappa_int=0,
                                       pitch=True):
    """Right-hand side d(a, b, c)/dt of the coupled-mode equations, for ``solve_ivp``.

    Modes ``a`` and ``c`` are damped at ``kappa`` (plus ``kappa_int / 2``) and
    coupled to each other by ``kappa * 1j * sin(phi_0)`` and by ``g_0``.
    Mode ``b`` couples to ``a`` and ``c`` through the time-dependent pump
    ``g_t`` with phases ``phi_a`` and ``phi_c``.

    Parameters
    ----------
    t : float
        Time (supplied by the integrator).
    y : sequence of 3 complex
        Current amplitudes ``(a, b, c)``.
    g_0 : float
        Direct coherent a-c coupling.
    kappa : float
        Coupling/damping rate; also passed as the rate argument of ``g_t``.
    phi_0, phi_a, phi_c : float
        Phases entering the a-c coupling, the a-b pump and the c-b pump.
    gamma_ph : float
        Width parameter of ``g_t`` and ``input_wave``.
    dl_in, dr_in : number
        On/off flags. Any non-zero value is replaced by
        ``input_wave(gamma_ph, t)``; the given magnitude is ignored.
    kappa_int : float
        Internal loss, entering as ``-kappa_int * a / 2`` and ``-kappa_int * c / 2``.
    pitch : bool
        If True the pump follows ``g_t(..., t)``, otherwise the time-reversed
        ``g_t(..., -t)``.

    Returns
    -------
    list of 3 complex
        ``[da/dt, db/dt, dc/dt]``.
    """
    g_time = g_t(kappa, gamma_ph, t) if pitch else g_t(kappa, gamma_ph, -t)

    # Non-zero inputs are flags: substitute the sech-shaped drive at time t.
    if dl_in != 0:
        dl_in = input_wave(gamma_ph, t)
    if dr_in != 0:
        dr_in = input_wave(gamma_ph, t)

    a, b, c = y

    adot = -1j * g_time * np.exp(1j * phi_a) * b - kappa * a - kappa * 1j * np.sin(phi_0) * c - np.sqrt(
        kappa) * dl_in - np.sqrt(kappa) * dr_in - 1j * g_0 * c - kappa_int * a / 2

    bdot = -1j * g_time * np.exp(-1j * phi_a) * a - 1j * g_time * np.exp(-1j * phi_c) * c

    # The c-mode sees the inputs with opposite phase factors exp(-/+ 1j*phi_0).
    cdot = -1j * g_time * np.exp(1j * phi_c) * b - kappa * c - kappa * 1j * np.sin(phi_0) * a - np.sqrt(
        kappa) * np.exp(-1j * phi_0) * dl_in \
        - np.sqrt(kappa) * np.exp(1j * phi_0) * dr_in - 1j * g_0 * a - kappa_int * c / 2

    return [adot, bdot, cdot]


def solve_emission_symmetric_time_vary_g(tvals, g0, kappa, phi_0, phi_a, phi_c, gamma_ph,
                                         init_cond=None, dl_in=0, dr_in=0, kappa_int=0, pitch=True):
    """Integrate ``langevin_emission_symmetric_vary_g`` over ``tvals`` and form the output fields.

    Parameters
    ----------
    tvals : array_like
        Report times. Integration runs from ``tvals[0]`` to ``tvals[-1]``.
    g0, kappa, phi_0, phi_a, phi_c, gamma_ph, kappa_int, pitch
        Forwarded unchanged to ``langevin_emission_symmetric_vary_g``.
    init_cond : sequence of 3 complex, optional
        Initial amplitudes ``(a, b, c)``. Defaults to ``[0, 1, 0]``, i.e. one
        unit of amplitude in ``b``.
    dl_in, dr_in : number
        On/off flags for the left/right input. When non-zero,
        ``input_wave(gamma_ph, t)`` is injected and added to the matching output.

    Returns
    -------
    dict
        ``a``, ``b``, ``c``: mode amplitudes at the evaluated times.
        ``left``: ``sqrt(kappa)*a + sqrt(kappa)*exp(1j*phi_0)*c + dl_in``.
        ``right``: ``sqrt(kappa)*a + sqrt(kappa)*exp(-1j*phi_0)*c + dr_in``.

    NOTE(review): ``sol.status`` is not checked. If the integration stops
    early, the returned arrays are shorter than ``tvals``.
    """
    if init_cond is None:
        # fresh list per call instead of a shared mutable default argument
        init_cond = [0 + 0j, 1 + 0j, 0 + 0j]

    sol = solve_ivp(
        langevin_emission_symmetric_vary_g,  # equation system
        [tvals[0], tvals[-1]],  # time bounds
        init_cond,  # initial amplitudes
        t_eval=tvals,
        method='RK45',
        args=[g0, kappa, phi_0, phi_a, phi_c, gamma_ph, dl_in, dr_in, kappa_int, pitch]
    )

    # Mirror the flag -> waveform substitution done inside the RHS, so the
    # input can be added to the outputs on the same time grid.
    if dl_in != 0:
        dl_in = input_wave(gamma_ph, tvals)
    if dr_in != 0:
        dr_in = input_wave(gamma_ph, tvals)

    asol, bsol, csol = sol.y
    dl_out = np.sqrt(kappa) * asol + np.sqrt(kappa) * np.exp(1j * phi_0) * csol + dl_in
    dr_out = np.sqrt(kappa) * asol + np.sqrt(kappa) * np.exp(-1j * phi_0) * csol + dr_in

    return dict(a=asol, b=bsol, c=csol, left=dl_out, right=dr_out)


def chiral_coupler_pitch_and_catch(kappa, gamma_ph, tvals, coupling_cancel_ratio=-1, init_cond=None,
                                   dl_in=0, dr_in=0, kappa_int=0, pitch=True):
    """Run the emission simulation for both pump phases ``phi_c = -pi/2`` and ``+pi/2``.

    Uses ``phi_0 = pi/2``, ``phi_a = 0`` and ``g0 = coupling_cancel_ratio * kappa``.

    Parameters
    ----------
    kappa, gamma_ph, tvals, dl_in, dr_in, kappa_int, pitch
        Forwarded to ``solve_emission_symmetric_time_vary_g``.
    coupling_cancel_ratio : float
        Sets the direct coupling ``g0`` in units of ``kappa``.
    init_cond : sequence of 3 complex, optional
        Initial amplitudes ``(a, b, c)``. Defaults to ``[0, 1, 0]``.

    Returns
    -------
    dict
        Maps each ``phi_c`` value (np.float64, ``-0.5*pi`` and ``0.5*pi``) to
        the result dict of ``solve_emission_symmetric_time_vary_g``.
    """
    if init_cond is None:
        # fresh list per call instead of a shared mutable default argument
        init_cond = [0 + 0j, 1 + 0j, 0 + 0j]

    g0 = coupling_cancel_ratio * kappa

    # Fixed phases; phi_c is swept below.
    phi_0 = 0.5 * np.pi
    phi_a = 0 * np.pi

    phi_cs = np.array([-0.5, 0.5]) * np.pi

    perfect_sols = {}
    for phi_c in phi_cs:
        perfect_sols[phi_c] = solve_emission_symmetric_time_vary_g(
            tvals, g0, kappa, phi_0, phi_a, phi_c, gamma_ph,
            init_cond=init_cond, dl_in=dl_in, dr_in=dr_in, kappa_int=kappa_int, pitch=pitch
        )

    return perfect_sols


# %%
# SNAIL calculation for estimating the effect of alpha

# Physical constants in SI units.
h = 6.62607015e-34  # Planck constant [J s]
hbar = h / (2 * np.pi)  # reduced Planck constant [J s]
phi_0 = 2.067833848e-15 / (2 * np.pi)  # reduced flux quantum Phi_0 / (2 pi) [Wb]


# Taylor-expansion coefficients of the SNAIL energy about phi

# c1..c4: n-th order coefficients (n-th derivative divided by n!)
def c1_coeff(phi, phi_ext, alpha):
    """First-order SNAIL expansion coefficient ``alpha*sin(phi) - sin((phi_ext - phi)/3)``.

    This coefficient is driven to zero by ``solve`` to locate phi_min. It
    matches the derivative of U(phi) = -alpha*cos(phi) - 3*cos((phi_ext - phi)/3).
    """
    alpha_term = alpha * np.sin(phi)
    chain_term = np.sin((phi_ext - phi) / 3)
    return alpha_term - chain_term


def c2_coeff(phi, phi_ext, alpha):
    """Second-order SNAIL expansion coefficient ``(cos((phi_ext - phi)/3)/3 + alpha*cos(phi)) / 2``."""
    chain_term = np.cos((phi_ext - phi) / 3) / 3
    alpha_term = alpha * np.cos(phi)
    return (chain_term + alpha_term) / 2.0


def c3_coeff(phi, phi_ext, alpha):
    """Third-order SNAIL expansion coefficient ``(sin((phi_ext - phi)/3)/9 - alpha*sin(phi)) / 6``."""
    chain_term = np.sin((phi_ext - phi) / 3) / 9
    alpha_term = alpha * np.sin(phi)
    return (chain_term - alpha_term) / 6.0


def c4_coeff(phi, phi_ext, alpha):
    """Fourth-order SNAIL expansion coefficient ``(-cos((phi_ext - phi)/3)/27 - alpha*cos(phi)) / 24``."""
    chain_term = np.cos((phi_ext - phi) / 3) / 27
    alpha_term = alpha * np.cos(phi)
    return (-chain_term - alpha_term) / 24.0


# Defining a solver to get the phi_min given a phi_ext
def solve(phi_ext, alpha, unsolved_value):
    """Residual for the root finder: ``[c1(unsolved_value)]`` at fixed ``phi_ext``, ``alpha``.

    The residual vanishes at a stationary point phi_min of the SNAIL energy,
    where the first-order Taylor coefficient c1 is zero.
    """
    residual = c1_coeff(unsolved_value, phi_ext, alpha)
    return [residual]


# Now for a given range of external flux values let's compute the corresponding phi_min's for the expansion


def phi_m(initial_guess, phi_ext, alpha):
    """Find phi_min (where c1 == 0) for one external flux, starting from ``initial_guess``.

    Uses ``scipy.optimize.broyden1`` with ``x_tol=1e-10`` and ``maxiter=100000``.
    The result has the same shape as ``initial_guess``.
    """
    residual = partial(solve, phi_ext, alpha)
    return broyden1(residual, initial_guess, x_tol=1e-10, maxiter=100000)


def L_j(initial_guess, phi_ext, alpha, L_0):
    """Effective inductance ``L_0 / (2 * c2)`` at the phi_min solved for ``phi_ext``."""
    phi_min = phi_m(initial_guess, phi_ext, alpha)
    return L_0 / 2 / c2_coeff(phi_min, phi_ext, alpha)


def get_c2(initial_guess, phi_ext, alpha, L_0):
    """Return c2 evaluated at the phi_min solved for ``phi_ext``.

    ``L_0`` is accepted for signature compatibility but is not used.
    """
    return c2_coeff(phi_m(initial_guess, phi_ext, alpha), phi_ext, alpha)


def get_c3(initial_guess, phi_ext, alpha):
    """Return c3 evaluated at the phi_min solved for ``phi_ext``."""
    return c3_coeff(phi_m(initial_guess, phi_ext, alpha), phi_ext, alpha)


def get_c4(initial_guess, phi_ext, alpha):
    """Return c4 evaluated at the phi_min solved for ``phi_ext``."""
    return c4_coeff(phi_m(initial_guess, phi_ext, alpha), phi_ext, alpha)


def snail_all_coeff_calculation(L_0, alpha, phi_exts=None):
    """Sweep the external flux and compute L_j, c2, c3 and c4 at each flux point.

    Parameters
    ----------
    L_0 : float
        Inductance scale; ``L_j = L_0 / (2 * c2)``.
    alpha : float
        SNAIL asymmetry parameter passed to the coefficient functions.
    phi_exts : array_like, optional
        External flux values. Defaults to 501 points in ``[-1.5*pi, 1.5*pi]``.

    Returns
    -------
    phi_exts, L_js, c2, c3, c4
        The flux grid and one float64 array per quantity.
    """
    if phi_exts is None:
        phi_exts = np.linspace(-1.5, 1.5, 501) * np.pi

    n_points = len(phi_exts)
    L_js = np.zeros(n_points)
    c2 = np.zeros(n_points)
    c3 = np.zeros(n_points)
    c4 = np.zeros(n_points)
    initial_guess = 0.0

    for i, phi_ext in enumerate(phi_exts):
        # Solve for phi_min once per flux point. Previously L_j/get_c2/get_c3/get_c4
        # each repeated the identical broyden1 solve (4x the work, same result).
        phi_min = phi_m(initial_guess, phi_ext, alpha)
        c2[i] = c2_coeff(phi_min, phi_ext, alpha)
        c3[i] = c3_coeff(phi_min, phi_ext, alpha)
        c4[i] = c4_coeff(phi_min, phi_ext, alpha)
        L_js[i] = L_0 / 2 / c2[i]

    return phi_exts, L_js, c2, c3, c4


# %%
# for calculating quantum state transfer

# set up the two qubit state transfer Hamiltonian
#
def two_qubit_transfer_H(parameters):
    """Build the idealized two-node chiral-coupler state-transfer model for ``qt.mesolve``.

    Node 1 uses modes a1, a2 (truncated to ``dim`` = 2 levels) and b1 (truncated
    to ``dim_b`` = 7 levels); node 2 uses a3, a4 and b2. The tensor-product
    ordering is (a1, a2, b1, a3, a4, b2).

    Parameters
    ----------
    parameters : dict
        Only ``parameters['kappa']`` is read here. The time-dependent
        coefficients ``g_t1``/``g_t2`` read ``args['gamma']`` and
        ``args['gamma_ph']`` from the ``args`` dict passed to ``qt.mesolve``.

    Returns
    -------
    H : list
        QuTiP list-format Hamiltonian ``[H_sys1 + H_J, [H1, g_t1], [H2, g_t2]]``.
    c_ops : list
        Collapse operators ``sqrt(2*kappa) * a_i`` followed by cross terms built
        with ``qt.lindblad_dissipator`` (superoperators).
    """
    # Pump envelope for node 1. Algebraically the same as the module-level g_t,
    # since (1 - tanh x) * cosh(x)**2 == (1 + exp(-2x)) / 2.
    def g_t1(t, args):
        gamma = args['gamma']
        gamma_ph = args['gamma_ph']
        x = gamma_ph * t / 2.0
        Cosh = (np.exp(x) + np.exp(-x)) / 2.0
        Tanh = (np.exp(x) - np.exp(-x)) / (np.exp(x) + np.exp(-x))

        up = -gamma_ph * np.sqrt(gamma_ph) * Tanh + 2 * gamma * np.sqrt(gamma_ph)
        down = -8 * np.sqrt(-gamma_ph / 8 + gamma * (1 - Tanh) * Cosh ** 2 / 2.0)
        return up / down

    # Same envelope evaluated at -t (time-reversed) for node 2.
    def g_t2(t, args):
        gamma = args['gamma']
        gamma_ph = args['gamma_ph']
        t = -t
        x = gamma_ph * t / 2.0
        Cosh = (np.exp(x) + np.exp(-x)) / 2.0
        Tanh = (np.exp(x) - np.exp(-x)) / (np.exp(x) + np.exp(-x))

        up = -gamma_ph * np.sqrt(gamma_ph) * Tanh + 2 * gamma * np.sqrt(gamma_ph)
        down = -8 * np.sqrt(-gamma_ph / 8 + gamma * (1 - Tanh) * Cosh ** 2 / 2.0)
        return up / down

    # Hilbert-space truncations: a-modes and b-modes.
    dim = 2
    dim_b = 7
    # Tensor ordering: (a1, a2, b1, a3, a4, b2).
    a1 = qt.tensor(qt.destroy(dim), qt.qeye(dim), qt.qeye(dim_b), qt.qeye(dim), qt.qeye(dim), qt.qeye(dim_b))
    a2 = qt.tensor(qt.qeye(dim), qt.destroy(dim), qt.qeye(dim_b), qt.qeye(dim), qt.qeye(dim), qt.qeye(dim_b))
    b1 = qt.tensor(qt.qeye(dim), qt.qeye(dim), qt.destroy(dim_b), qt.qeye(dim), qt.qeye(dim), qt.qeye(dim_b))

    a3 = qt.tensor(qt.qeye(dim), qt.qeye(dim), qt.qeye(dim_b), qt.destroy(dim), qt.qeye(dim), qt.qeye(dim_b))
    a4 = qt.tensor(qt.qeye(dim), qt.qeye(dim), qt.qeye(dim_b), qt.qeye(dim), qt.destroy(dim), qt.qeye(dim_b))
    b2 = qt.tensor(qt.qeye(dim), qt.qeye(dim), qt.qeye(dim_b), qt.qeye(dim), qt.qeye(dim), qt.destroy(dim_b))

    a_ops = [a1, a2, a3, a4]
    # kappa = 5.0 * 2 * np.pi
    kappa = parameters['kappa']

    # Phases of the four a-modes; phi3 is offset from phi2 by phi_dis.
    phi1 = 0
    phi2 = np.pi / 2

    phi_dis = np.pi

    phi3 = phi2 + phi_dis
    phi4 = phi3 + np.pi / 2

    phis = np.array([phi1, phi2, phi3, phi4])

    # Pump phases of the a-b couplings within each node.
    phi_g11 = 0
    phi_g12 = np.pi / 2

    phi_g21 = 0
    phi_g22 = np.pi / 2

    # NOTE(review): ``args`` is assigned but never used in this function.
    args = parameters

    # Static hopping within each node, plus the pumped a-b couplings
    # (time dependence supplied by g_t1 / g_t2 below).
    H_sys1 = - kappa * (a1.dag() * a2 + a1 * a2.dag()) - kappa * (a3.dag() * a4 + a3 * a4.dag())
    H1 = np.exp(1j * phi_g11) * a1.dag() * b1 + np.exp(-1j * phi_g11) * a1 * b1.dag() + np.exp(
        1j * phi_g12) * a2.dag() * b1 + np.exp(-1j * phi_g12) * a2 * b1.dag()
    H2 = np.exp(1j * phi_g21) * a3.dag() * b2 + np.exp(-1j * phi_g21) * a3 * b2.dag() + np.exp(
        1j * phi_g22) * a4.dag() * b2 + np.exp(-1j * phi_g22) * a4 * b2.dag()

    # Coherent exchange between every pair i < j. Both coefficients reduce to
    # kappa * sin(phis[j] - phis[i]), so each term is Hermitian.
    H_J = 0 * a1
    for i in range(4):
        for j in range(4):
            if j > i:
                H_J += kappa * ((np.exp(1j * (phis[j] - phis[i])) - np.exp(-1j * (phis[j] - phis[i]))) / 2j) * (
                        a_ops[i].dag() * a_ops[j]) \
                       + kappa * ((np.exp(-1j * (phis[j] - phis[i])) - np.exp(1j * (phis[j] - phis[i]))) / -2j) * (
                               a_ops[i] * a_ops[j].dag())

    # Local decay of each a-mode.
    c_ops = []
    for i in range(4):
        c_ops.append(np.sqrt(2 * kappa) * a_ops[i])

    # Correlated-decay cross terms for i != j, weighted by
    # kappa * 2*cos(phis[j] - phis[i]). These are superoperators, which
    # qt.mesolve accepts in c_ops alongside collapse operators.
    for i in range(4):
        for j in range(4):
            if i != j:
                super_op = kappa * (np.exp(1j * np.abs(phis[j] - phis[i])) + np.exp(
                    -1j * np.abs(phis[j] - phis[i]))) * qt.lindblad_dissipator(a_ops[i], a_ops[j])
                c_ops.append(super_op)

    H = [H_sys1 + H_J, [H1, g_t1], [H2, g_t2]]
    # time = np.linspace(-5, 5, 1001)
    # dt = time[1] - time[0]
    return H, c_ops


# do the qutip simulation
def chiral_coupler_state_transfer(psi0, time, H, c_ops, args):
    """Evolve ``psi0`` over ``time`` under (``H``, ``c_ops``) with ``qt.mesolve``.

    No expectation operators are requested, so the returned result carries the
    states at every time in ``time``. ``args`` is passed to the time-dependent
    coefficient functions of ``H``.
    """
    return qt.mesolve(H, psi0, time, c_ops=c_ops, e_ops=[], args=args)


# get Wigner function
def get_wigner_tomo(output):
    """Wigner functions of subsystem 2 at the first state and subsystem 5 at the last state.

    With the tensor ordering used by the Hamiltonian builders in this file
    (a1, a2, b1, a3, a4, b2), subsystems 2 and 5 are b1 and b2.

    Returns
    -------
    (testx, W_init, W_final)
        The 201-point grid on [-5, 5], used for both phase-space axes, and the
        two Wigner arrays.
    """
    grid = np.linspace(-5, 5, 201)

    states = output.states
    rho_start = states[0].ptrace(2)
    rho_end = states[-1].ptrace(5)

    return grid, qt.wigner(rho_start, grid, grid), qt.wigner(rho_end, grid, grid)


# %%
# Quantum state transfer with practical parameters

def get_chiral_coupler_H(kappa_ratios, kappa=5 * 2 * np.pi, a_T1=None, a_T2=None, b_T1=None, b_T2=None,
                         residual_coupling_ratio=1.0, pump_leakage_ratio=None, gamma_ph=None, dim=2, dim_b=2):
    """Build the two-node chiral-coupler Lindblad model with practical (imperfect) parameters.

    Tensor ordering is (a1, a2, b1, a3, a4, b2). The a-modes are truncated to
    ``dim`` levels and the b-modes to ``dim_b`` levels.

    Parameters
    ----------
    kappa_ratios : sequence of 4 floats
        Per-a-mode rate scaling: ``kappa_a[i] = kappa * kappa_ratios[i]``.
    kappa : float
        Base rate. Also stored as ``args['gamma']`` for the pump envelopes.
    a_T1, a_T2 : float, optional
        If given, add ``sqrt(1/a_T1) * a_i`` and ``sqrt(1/a_T2) * a_i^dag a_i``
        to every a-mode.
    b_T1, b_T2 : float, optional
        If given, add ``sqrt(1/b_T1) * b`` and ``sqrt(1/b_T2) * b^dag b`` to b1 and b2.
    residual_coupling_ratio : float
        Scales the static ``-kappa * (a1^dag a2 + h.c.)`` and
        ``-kappa * (a3^dag a4 + h.c.)`` terms.
    pump_leakage_ratio : float, optional
        If given, ``pump_leakage_ratio * H1`` is driven by node 2's envelope and
        ``pump_leakage_ratio * H2`` by node 1's envelope.
    gamma_ph : float, optional
        Stored as ``args['gamma_ph']``. Defaults to ``0.1 * kappa``.
    dim, dim_b : int
        Truncations of the a- and b-modes.

    Returns
    -------
    (H, c_ops, args)
        QuTiP list-format Hamiltonian, collapse operators/superoperators, and
        the ``args`` dict the time-dependent coefficients expect.
    """
    def g_t1(t, args):
        # Node-1 pump envelope: the module-level g_t (same expression).
        return g_t(args['gamma'], args['gamma_ph'], t)

    def g_t2(t, args):
        # Node-2 pump envelope: g_t evaluated at -t (time-reversed).
        return g_t(args['gamma'], args['gamma_ph'], -t)

    # Tensor ordering: (a1, a2, b1, a3, a4, b2).
    a1 = qt.tensor(qt.destroy(dim), qt.qeye(dim), qt.qeye(dim_b), qt.qeye(dim), qt.qeye(dim), qt.qeye(dim_b))
    a2 = qt.tensor(qt.qeye(dim), qt.destroy(dim), qt.qeye(dim_b), qt.qeye(dim), qt.qeye(dim), qt.qeye(dim_b))
    b1 = qt.tensor(qt.qeye(dim), qt.qeye(dim), qt.destroy(dim_b), qt.qeye(dim), qt.qeye(dim), qt.qeye(dim_b))

    a3 = qt.tensor(qt.qeye(dim), qt.qeye(dim), qt.qeye(dim_b), qt.destroy(dim), qt.qeye(dim), qt.qeye(dim_b))
    a4 = qt.tensor(qt.qeye(dim), qt.qeye(dim), qt.qeye(dim_b), qt.qeye(dim), qt.destroy(dim), qt.qeye(dim_b))
    b2 = qt.tensor(qt.qeye(dim), qt.qeye(dim), qt.qeye(dim_b), qt.qeye(dim), qt.qeye(dim), qt.destroy(dim_b))

    a_ops = [a1, a2, a3, a4]

    # Phases of the four a-modes; phi3 is offset from phi2 by phi_dis.
    phi1 = 0
    phi2 = np.pi / 2
    phi_dis = np.pi
    phi3 = phi2 + phi_dis
    phi4 = phi3 + np.pi / 2
    phis = np.array([phi1, phi2, phi3, phi4])

    kappa_as = np.array([kappa * kappa_ratios[i] for i in range(4)])

    # Pump phases of the a-b couplings within each node.
    phi_g11 = 0
    phi_g12 = np.pi / 2
    phi_g21 = 0
    phi_g22 = np.pi / 2

    # ``is None`` rather than ``== None``: the latter is ambiguous for array inputs.
    if gamma_ph is None:
        gamma_ph = kappa * 0.1
    args = {'gamma_ph': gamma_ph, 'gamma': kappa}

    H_sys1 = - kappa * residual_coupling_ratio * (a1.dag() * a2 + a1 * a2.dag()) - kappa * residual_coupling_ratio * (
            a3.dag() * a4 + a3 * a4.dag())
    H1 = np.exp(1j * phi_g11) * a1.dag() * b1 + np.exp(-1j * phi_g11) * a1 * b1.dag() + np.exp(
        1j * phi_g12) * a2.dag() * b1 + np.exp(-1j * phi_g12) * a2 * b1.dag()
    H2 = np.exp(1j * phi_g21) * a3.dag() * b2 + np.exp(-1j * phi_g21) * a3 * b2.dag() + np.exp(
        1j * phi_g22) * a4.dag() * b2 + np.exp(-1j * phi_g22) * a4 * b2.dag()

    # Coherent exchange for every pair i < j; both coefficients reduce to
    # sqrt(kappa_i * kappa_j) * sin(phis[j] - phis[i]).
    H_J = 0 * a1
    for i in range(4):
        for j in range(4):
            if j > i:
                H_J += np.sqrt(kappa_as[i] * kappa_as[j]) * (
                        (np.exp(1j * (phis[j] - phis[i])) - np.exp(-1j * (phis[j] - phis[i]))) / 2j) * (
                               a_ops[i].dag() * a_ops[j]) \
                       + np.sqrt(kappa_as[i] * kappa_as[j]) * (
                               (np.exp(-1j * (phis[j] - phis[i])) - np.exp(1j * (phis[j] - phis[i]))) / -2j) * (
                               a_ops[i] * a_ops[j].dag())

    # Per-mode external decay, then optional T1 / dephasing channels.
    c_ops = []
    for i in range(4):
        c_ops.append(np.sqrt(2 * kappa_as[i]) * a_ops[i])
        if a_T1 is not None:
            kappa_a_T1 = 1.0 / a_T1
            c_ops.append(np.sqrt(kappa_a_T1) * a_ops[i])
        if a_T2 is not None:
            kappa_a_T2 = 1.0 / a_T2
            c_ops.append(np.sqrt(kappa_a_T2) * a_ops[i].dag() * a_ops[i])

    if b_T1 is not None:
        kappa_b_T1 = 1.0 / b_T1
        c_ops.append(np.sqrt(kappa_b_T1) * b1)
        c_ops.append(np.sqrt(kappa_b_T1) * b2)

    if b_T2 is not None:
        kappa_b_T2 = 1.0 / b_T2
        c_ops.append(np.sqrt(kappa_b_T2) * b1.dag() * b1)
        c_ops.append(np.sqrt(kappa_b_T2) * b2.dag() * b2)

    # Correlated-decay cross terms (superoperators) for i != j, weighted by
    # sqrt(kappa_i * kappa_j) * 2*cos(phis[j] - phis[i]).
    for i in range(4):
        for j in range(4):
            if i != j:
                super_op = np.sqrt(kappa_as[i] * kappa_as[j]) * (np.exp(1j * np.abs(phis[j] - phis[i])) + np.exp(
                    -1j * np.abs(phis[j] - phis[i]))) * qt.lindblad_dissipator(a_ops[i], a_ops[j])
                c_ops.append(super_op)

    if pump_leakage_ratio is None:
        H = [H_sys1 + H_J, [H1, g_t1], [H2, g_t2]]
    else:
        # Each node's coupling also picks up a fraction of the other node's pump.
        H = [H_sys1 + H_J, [H1, g_t1], [H2, g_t2], [pump_leakage_ratio * H1, g_t2], [pump_leakage_ratio * H2, g_t1]]

    return (H, c_ops, args)


#%%
# get vna traces
def get_trace_for_batch_fitting(folder_path):
    """Load every VNA trace under ``folder_path`` and the bias current encoded in each entry name.

    Each entry of ``folder_path`` is expected to be a directory containing
    ``data.ddh5`` readable by ``load_single_vna_data``. The current is parsed
    from the 6th ``_``-separated field of the entry name. If that field is
    missing or not a number, the 5th field is used instead. Entries are
    processed in ``os.listdir`` order.

    Returns
    -------
    vna_freqs, vna_traces, currents : np.ndarray
        One row/value per entry.

    Raises
    ------
    IndexError, ValueError
        If neither name field yields a float.
    """
    fn_list = os.listdir(folder_path)
    vna_freqs = []
    vna_traces = []
    currents = []

    for fn in fn_list:
        # os.path.join instead of hard-coded '\\' so this also works off Windows.
        vna_freq, vna_trace = load_single_vna_data(os.path.join(folder_path, fn, 'data.ddh5'))
        vna_freqs.append(vna_freq)
        vna_traces.append(vna_trace)

        fields = fn.split("_")
        try:
            currents.append(float(fields[5]))
        except (IndexError, ValueError):
            # narrow catch: the bare ``except:`` also swallowed KeyboardInterrupt etc.
            currents.append(float(fields[4]))

    return np.array(vna_freqs), np.array(vna_traces), np.array(currents)


def batch_fitting_hanger(folder_path):
    """Fit every VNA trace under ``folder_path`` with ``HangerResponseBruno``.

    Each entry of ``folder_path`` is expected to be a directory holding a
    ``data.ddh5`` file readable by ``load_single_vna_data``. Entries are
    processed in ``os.listdir`` order. Each fit is started with
    ``Q_i=1e4`` and ``Q_e_mag=3e3``.

    Returns
    -------
    f_0s, Q_es, Q_is : np.ndarray
        Fitted ``f_0``, ``Q_e_mag`` and ``Q_i``, one value per entry.
    """
    fn_list = os.listdir(folder_path)
    n_files = len(fn_list)
    f_0s = np.zeros(n_files)
    Q_es = np.zeros(n_files)
    Q_is = np.zeros(n_files)

    for i, fn in enumerate(fn_list):
        # os.path.join instead of hard-coded '\\' so this also works off Windows.
        vna_freq, vna_trace = load_single_vna_data(os.path.join(folder_path, fn, 'data.ddh5'))
        # Fit each trace right after loading it. Traces no longer need to share
        # one length (stacking them into a single array required that).
        fit_result = HangerResponseBruno(np.asarray(vna_freq), np.asarray(vna_trace)).run(Q_i=1e4, Q_e_mag=3e3)
        f_0s[i] = fit_result.params['f_0']
        Q_es[i] = fit_result.params['Q_e_mag']
        Q_is[i] = fit_result.params['Q_i']

    return f_0s, Q_es, Q_is


#%%
# functions to get the isolation and gyration data

def get_isolation_and_gyration(fn_dict, sa_fn_dict, fn_background, DAC, params, gammas, plot=False):
    """Load isolation/gyration data for one DAC setting and compare it with theory.

    Prints ``gammas``, the isolation and the insertion loss. When ``plot`` is
    True, it also draws the pass/block, gyration and full isolation plots.

    Returns
    -------
    tuple
        ``(frequency, pass_data, block_data, calculate_result_dict, phase,
        sa_results, pump_phases, sa_calculate_results)``.
    """
    data_fn = fn_dict[DAC]
    sa_fn = sa_fn_dict[DAC]

    # NOTE(review): the loaded SA arrays are not used below.
    _sa_freq, _sa_phase, _sa_trace = load_sa_data(sa_fn)

    frequency, phase, normalized_response, sa_results = get_isolation_gyration_data(data_fn, fn_background, sa_fn)
    row_pass, row_block = params['phase_index1'], params['phase_index2']
    pass_data = normalized_response[row_pass, :]
    block_data = normalized_response[row_block, :]

    # NOTE(review): key is spelled 'extened_g1' here, while
    # process_isolation_gyration_data reads a per-DAC 'extended_g1' -- confirm
    # which spelling callers provide.
    (g_pump, g_pump_ratio, phase_offset, pump_phase_offset, f_center,
     g_c_ratio, phi0_ratio, freq_detune, amp_offset) = (params[key] for key in (
        'extened_g1', 'g_pump_ratio', 'phase_offset', 'pump_phase_offset', 'f_center',
        'g_c_ratio', 'phi0_ratio', 'freq_detune', 'amp_offset'))

    print(gammas)
    (deltas, pump_phases, results_pass,
     results_block, calculate_results) = isolation_gyration_calculation(g_pump, g_pump_ratio, g_c_ratio, phi0_ratio,
                                                                        freq_detune, gammas)

    # Theory curve for the gyration plot: |column 3| at the central probe detuning.
    probe_index = len(deltas) // 2
    sa_calculate_results = np.array(
        [np.abs(calculate_results[probe_index, k, 3]) for k in range(len(pump_phases))], dtype=float)

    block_min = 10 * np.log10(np.abs(results_block[:, 4]).min() ** 2)
    pass_min = 10 * np.log10(np.abs(results_pass[:, 4]).min() ** 2)
    print(f"The isolation at this bias point is {block_min} dB")
    print(f"The insertion loss at this bias point is {pass_min} dB")

    phase_factor = np.exp(1j * phase_offset)
    calculate_result_dict = {'deltas': deltas, 'results_pass': results_pass[:, 4] * phase_factor,
                             'results_block': results_block[:, 4] * phase_factor}

    if plot:
        setup_figure()
        freq_mhz = frequency[0] / 1e6
        plot_pass_block_data(freq_mhz, pass_data * np.exp(-0j * np.pi), freq_mhz,
                             block_data * np.exp(-0j * np.pi), f_center,
                             calculate_results=calculate_result_dict)
        plot_gyration_data(phase, sa_results, pump_phases, pump_phase_offset, sa_calculate_results, amp_offset)
        plot_full_isolation_data(freq_mhz, f_center, phase, normalized_response,
                                 deltas, pump_phases, pump_phase_offset, calculate_results)

    return frequency, pass_data, block_data, calculate_result_dict, phase, sa_results, pump_phases, sa_calculate_results


#%%
def process_isolation_gyration_data(fn_dict, fn_background, sa_fn_dict, params, gammas, plot=False):
    """Load isolation/gyration data for ``params['DAC']`` and compare it with theory.

    ``g_pump`` is taken from ``params['extended_g1'][DAC]`` and ``g_pump_ratio``
    is fixed at 1.0. The function prints the isolation and insertion loss and
    always calls ``setup_figure()``. When ``plot`` is True it also draws the
    pass/block, gyration and full isolation plots.

    Returns
    -------
    tuple
        ``(frequency, pass_data, block_data, calculate_result_dict, phase,
        sa_results, pump_phases, sa_calculate_results)``.
    """
    dac = params['DAC']
    data_fn = fn_dict[dac]
    sa_fn = sa_fn_dict[dac]

    # NOTE(review): the loaded SA arrays are not used below.
    _sa_freq, _sa_phase, _sa_trace = load_sa_data(sa_fn)

    frequency, phase, normalized_response, sa_results = get_isolation_gyration_data(data_fn, fn_background, sa_fn)
    row_pass, row_block = params['phase_index1'], params['phase_index2']
    pass_data = normalized_response[row_pass, :]
    block_data = normalized_response[row_block, :]

    g_pump = params['extended_g1'][dac]
    g_pump_ratio = 1.0
    (phase_offset, pump_phase_offset, f_center, g_c_ratio,
     phi0_ratio, freq_detune, amp_offset) = (params[key] for key in (
        'phase_offset', 'pump_phase_offset', 'f_center', 'g_c_ratio',
        'phi0_ratio', 'freq_detune', 'amp_offset'))

    (deltas, pump_phases, results_pass,
     results_block, calculate_results) = isolation_gyration_calculation(g_pump, g_pump_ratio, g_c_ratio, phi0_ratio,
                                                                        freq_detune, gammas)

    # Theory curve for the gyration plot: |column 3| at the central probe detuning.
    probe_index = len(deltas) // 2
    sa_calculate_results = np.array(
        [np.abs(calculate_results[probe_index, k, 3]) for k in range(len(pump_phases))], dtype=float)

    block_min = 10 * np.log10(np.abs(results_block[:, 4]).min() ** 2)
    pass_min = 10 * np.log10(np.abs(results_pass[:, 4]).min() ** 2)
    print(f"The isolation at this bias point is {block_min} dB")
    print(f"The insertion loss at this bias point is {pass_min} dB")

    setup_figure()
    phase_factor = np.exp(1j * phase_offset)
    calculate_result_dict = {'deltas': deltas, 'results_pass': results_pass[:, 4] * phase_factor,
                             'results_block': results_block[:, 4] * phase_factor}

    if plot:
        freq_mhz = frequency[0] / 1e6
        plot_pass_block_data(freq_mhz, pass_data * np.exp(-0j * np.pi), freq_mhz,
                             block_data * np.exp(-0j * np.pi), f_center,
                             calculate_results=calculate_result_dict)
        plot_gyration_data(phase, sa_results, pump_phases, pump_phase_offset, sa_calculate_results, amp_offset)
        plot_full_isolation_data(freq_mhz, f_center, phase, normalized_response,
                                 deltas, pump_phases, pump_phase_offset, calculate_results)

    return frequency, pass_data, block_data, calculate_result_dict, phase, sa_results, pump_phases, sa_calculate_results


#%%
# rearrange data along the pump-phase axis
def shfit_data_phase(data, phase, pump_phase_offset, phase_step=10):
    """Reorder ``data`` along its first axis by pump phase.

    Output row ``i`` is the row of ``data`` whose ``phase`` entry equals
    ``(-180 + i * phase_step + pump_phase_offset) mod 360``. Phases are in the
    same units as ``phase`` (degrees, given the 360 modulus). Matching uses
    exact equality, so exactly one entry of ``phase`` must match each target;
    otherwise numpy raises a broadcasting ValueError.

    Parameters
    ----------
    data : array_like
        Array whose first axis is indexed like ``phase``. Complex data stays complex.
    phase : array_like
        Phase value of each row of ``data``.
    pump_phase_offset : float
        Offset added before wrapping into [0, 360).
    phase_step : float, optional
        Spacing between consecutive output rows (default 10).

    Returns
    -------
    np.ndarray
        Same shape as ``data``. The dtype is float64, or complex when ``data``
        is complex.
    """
    data = np.asarray(data)
    phase = np.asarray(phase)
    # np.zeros(data.shape) was always float64, which silently dropped the
    # imaginary part of complex data; promote to at least float64 instead.
    new_data = np.zeros(data.shape, dtype=np.result_type(data.dtype, np.float64))

    for i in range(len(phase)):
        target = np.mod(-180 + i * phase_step + pump_phase_offset, 360)
        new_data[i] = data[np.where(phase == target)]

    return new_data